{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c7bc97f-89b4-4275-baf0-825ebda0bae7",
   "metadata": {},
   "source": [
    "# Open Graph Benchmark (OGB) Node Prediction for Micrososoft Academic Graph (OGBN-MAG)\n",
    "\n",
    "The [Open Graph Benchmark (OGB)](https://ogb.stanford.edu/) is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs.\n",
    "\n",
    "This notebook demonstrate a [Graph Neural Network ([GNN])](https://en.wikipedia.org/wiki/Graph_neural_network) based on [TF-GNN](https://github.com/tensorflow/gnn) [OGBNMAG tutorial](https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb), but using GoMLX.\n",
    "\n",
    "The task [OGBN-MAG is described here](https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag). This demo comes with a library that does the downloading, parsing and converting of the data to tensors for fast use.\n",
    "\n",
    "The model is experimental, but it includes a basic GNN library that can be used for different projects.\n",
    "\n",
    "**EXPERIMENTAL**, it has been used only for OGBN-MAG still.\n",
    "\n",
    "See also [OGBN-MAG Leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-papers100M) -- take with a grain of salt because different models use different tricks that may be considered leaking, or using extra data from the outside (so more data), and in some cases are very overfit to the task. But still, it's a fun dataset to work with.\n",
    "\n",
    "See the subdirectory `demo` for a command line of the trainer (that can be run on the cloud somewhere) and it will save datapoints that can be plotted in the notebook later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b357050e-2756-465f-9e81-c83c49b6a84b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t- Replace rule for module \"github.com/janpfeifer/gonb\" to local directory \"/home/janpf/Projects/gonb\" already exists.\n",
      "\t- Replace rule for module \"github.com/gomlx/gomlx\" to local directory \"/home/janpf/Projects/gomlx\" already exists.\n"
     ]
    }
   ],
   "source": [
    "!*rm -f go.work && go work init && go work use . \"${HOME}/Projects/gomlx\" \"${HOME}/Projects/gonb\"\n",
    "%goworkfix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "598ddc21-be30-4883-ad85-e0f9199ef0d7",
   "metadata": {},
   "source": [
    "## Downloading Dataset\n",
    "\n",
    "The method `mag.Download()` will download the dataset to the data directory, if not yet downloaded.\n",
    "It then converts the dataset to tensors, which are then available for use.\n",
    "\n",
    "The tensor is then saved for faster access. After saved, the next call to `mag.Download()` will take 1/2s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fb544cc2-b70b-4975-acda-fd39ccfb0f94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed: 1.973077748s\n"
     ]
    }
   ],
   "source": [
    "import (\n",
    "    \"flag\"\n",
    "    . \"github.com/gomlx/gomlx/pkg/core/graph\"\n",
    "    mag \"github.com/gomlx/gomlx/examples/ogbnmag\"\n",
    "    \"github.com/janpfeifer/must\"\n",
    "\t\"github.com/gomlx/gomlx/backends\"\n",
    "\n",
    "    _ \"github.com/gomlx/gomlx/backends/default\"\n",
    ")\n",
    "\n",
    "var (\n",
    "    flagDataDir   = flag.String(\"data\", \"~/work/ogbnmag\", \"Directory to cache downloaded and generated dataset files.\")\n",
    "    backend = backends.MustNew()\n",
    "    _ *Node = nil\n",
    ")\n",
    "\n",
    "%%\n",
    "start := time.Now()\n",
    "must.M(mag.Download(*flagDataDir))\n",
    "fmt.Printf(\"Elapsed: %s\\n\", time.Since(start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1ff51fa4-a211-41e7-bd05-b7bc9b38906c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed: 1.63463491s\n",
      "\tfloat32(-1.439)\n",
      "\tfloat32(1.697)\n",
      "\tfloat32(0.01028)\n",
      "\tfloat32(0.05374)\n",
      "\tfloat32(0.2318)\n"
     ]
    }
   ],
   "source": [
    "%%\n",
    "start := time.Now()\n",
    "must.M(mag.Download(*flagDataDir))\n",
    "fmt.Printf(\"Elapsed: %s\\n\", time.Since(start))\n",
    "\n",
    "\n",
    "results := graph.NewExec(backend, func (x *Node) []*Node {\n",
    "    mean := ReduceAllMean(x)\n",
    "    variance := ReduceAllMean(Square(Sub(x, mean)))\n",
    "    stddev := Sqrt(variance)\n",
    "    return []*Node {\n",
    "        ReduceAllMin(x),\n",
    "        ReduceAllMax(x),\n",
    "        mean,\n",
    "        variance,\n",
    "        stddev,\n",
    "    }\n",
    "}).Call(mag.PapersEmbeddings)\n",
    "for _, t := range results {\n",
    "    fmt.Printf(\"\\t%s\\n\", t)\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dffff58-3a1d-4587-8a74-9ee9c58c129b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed: 1.577930017s\n"
     ]
    }
   ],
   "source": [
    "%%\n",
    "start := time.Now()\n",
    "must.M(mag.Download(*flagDataDir))\n",
    "fmt.Printf(\"Elapsed: %s\\n\", time.Since(start))\n",
    "\n",
    "tensors.MutableFlatData[float32](mag.PapersEmbeddings, func (flat []float32) {\n",
    "    slices.Sort(flat)\n",
    "    numQuantiles := 20\n",
    "    fmt.Printf(\"%g\\n\", flat[0])\n",
    "    for ii := range numQuantiles-2 {\n",
    "        idx := (ii+1)*len(flat)/(numQuantiles)\n",
    "        fmt.Printf(\"%g\\n\", flat[idx])\n",
    "    }\n",
    "    fmt.Printf(\"%g\\n\", flat[numQuantiles-1])\n",
    "})\n",
    "// ZedMono NFM Light g"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13ed15ef-e36f-4bf8-a360-ff5bbc047b90",
   "metadata": {},
   "source": [
    "## FNN model, No Graph\n",
    "\n",
    "The first model will only use the paper features, and no relations. It serves as a baseline.\n",
    "\n",
    "In a quick experiment, without much hyperparameter tuning, we got 27.27% on test accuracy, which is inline with the corresponding results in [the leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag) (the Multi Layer Perceptron MLP entry in the bottom)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f99dda4b-4bc6-493d-a02c-a936e510828a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import (\n",
    "\t\"github.com/gomlx/gomlx/pkg/ml/context\"\n",
    "\t\"github.com/gomlx/gomlx/pkg/ml/layers/regularizers\"\n",
    ")\n",
    "func config() (ctx *context.Context) {\n",
    "    ctx = context.NewContext()\n",
    "    ctx.RngStateReset()\n",
    "    ctx.SetParams(map[string]any{\n",
    "        \"train_steps\": 400_000, \n",
    "        \"batch_size\": 128,\n",
    "        \"optimizer\": \"adamw\", \n",
    "        optimizers.LearningRateKey: 0.0001,\n",
    "        regularizers.ParamL2: 1e-4,\n",
    "        \"normalization\": \"layer\",\n",
    "        \"dropout\": 0.1,\n",
    "        \"hidden_layers\": 2,\n",
    "        \"num_nodes\": 256,\n",
    "        \"plots\": true,\n",
    "    })\n",
    "    return\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f2bc12d5-dd9b-47a0-a6eb-0e97ebbb6511",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading: \"checkpoint-n0000033-20240722-103831-step-00400000\"\n",
      "loading: \"checkpoint-n0000031-20240722-103650-step-00380971\"\n",
      "loading: \"checkpoint-n0000032-20240722-103751-step-00393020\"\n",
      "> restarting training from global_step=400000\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "(...collecting metrics, minimum 3 required to start plotting...)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Results on seeds_train:\n",
      "\tMean Loss+Regularization (#loss+): 3.390\n",
      "\tMean Loss (#loss): 3.170\n",
      "\tMean Accuracy (#acc): 24.04%\n",
      "Results on seeds_valid:\n",
      "\tMean Loss+Regularization (#loss+): 3.448\n",
      "\tMean Loss (#loss): 3.227\n",
      "\tMean Accuracy (#acc): 23.30%\n",
      "Results on seeds_test:\n",
      "\tMean Loss+Regularization (#loss+): 3.370\n",
      "\tMean Loss (#loss): 3.149\n",
      "\tMean Accuracy (#acc): 24.30%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import \"github.com/gomlx/gomlx/examples/ogbnmag/fnn\"\n",
    "\n",
    "%%\n",
    "must.M(mag.Download(*flagDataDir))\n",
    "ctx := config()\n",
    "ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"fnn-baseline\"))\n",
    "ctx.SetParam(\"num_checkpoints\", 10)\n",
    "ctx.SetParam(\"train_steps\", 400_000)\n",
    "\n",
    "// Using KAN Kolmogorov–Arnold Networks\n",
    "ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"kan-baseline\"))\n",
    "ctx.SetParam(\"kan\", true)\n",
    "ctx.SetParam(\"num_nodes\", 48)\n",
    "ctx.SetParam(\"hidden_layers\", 2)\n",
    "// ctx.SetParam(\"kan_bspline_magnitude_l1\", 1e-3)\n",
    "\n",
    "err := fnn.Train(backend, ctx)\n",
    "if  err != nil {\n",
    "    fmt.Printf(\"%+v\\n\", err)\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8906d875-13ed-4ebd-af77-33fbec8c66ed",
   "metadata": {},
   "source": [
    "## GNN (Graph Neural Networks)\n",
    "\n",
    "To use GNNs we need first graphs to run the model on. The issue is, generally the graphs in real life are too large (social networks, relational datases, etc.), including OGNB-MAG.\n",
    "\n",
    "So instead we use sampled sub-graphs to train. For inference we can also use sampled subgraphs, but later we show a work around where we can do inference, a layer of nodes at a time, where we don't need sampling.\n",
    "\n",
    "### Sampling Sub-Graphs\n",
    "\n",
    "We follow the same sampling strategy used in [TensorFlow GNN](https://github.com/tensorflow/gnn) library, describe it its [OGBN-MAG notebook](https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb).\n",
    "\n",
    "The `magSampler` variable loads the definition of the data graph: its node and edge sets. The `magStrategy` defines how to sample from those nodes and edges. See it's specification in [gomlx/examples/ogbnmag/sampling.go](https://github.com/gomlx/gomlx/blob/main/examples/ogbnmag/sampling.go), it's only some 20 lines long.\n",
    "\n",
    "The print out of the sampler and the strategy we are using:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e343e5fd-27a4-494f-a7d4-babc14dec453",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampler: 4 node types, 8 edge types, Frozen\n",
      "\tNodeType \"papers\": 736,389 items\n",
      "\tNodeType \"authors\": 1,134,649 items\n",
      "\tNodeType \"institutions\": 8,740 items\n",
      "\tNodeType \"fields_of_study\": 59,965 items\n",
      "\tEdgeType \"citedBy\": [\"papers\"]->[\"papers\"], 5,416,271 edges\n",
      "\tEdgeType \"affiliatedWith\": [\"authors\"]->[\"institutions\"], 1,043,998 edges\n",
      "\tEdgeType \"affiliations\": [\"institutions\"]->[\"authors\"], 1,043,998 edges\n",
      "\tEdgeType \"hasTopic\": [\"papers\"]->[\"fields_of_study\"], 7,505,078 edges\n",
      "\tEdgeType \"topicHasPapers\": [\"fields_of_study\"]->[\"papers\"], 7,505,078 edges\n",
      "\tEdgeType \"writes\": [\"authors\"]->[\"papers\"], 7,145,660 edges\n",
      "\tEdgeType \"writtenBy\": [\"papers\"]->[\"authors\"], 7,145,660 edges\n",
      "\tEdgeType \"cites\": [\"papers\"]->[\"papers\"], 5,416,271 edges\n",
      "\n",
      "Sampling strategy: (13 Rules)\n",
      "> Rule \"seeds\": type=Node, nodeType=\"papers\", Shape=(Int32)[128] (size=128), NodeSet.size=629571\n",
      "  > Rule \"citations\": type=Edge, nodeType=\"papers\", Shape=(Int32)[128 8] (size=1024), SourceRule=\"seeds\", EdgeType=\"cites\"\n",
      "    > Rule \"citationsAuthors\": type=Edge, nodeType=\"authors\", Shape=(Int32)[128 8 8] (size=8192), SourceRule=\"citations\", EdgeType=\"writtenBy\"\n",
      "      > Rule \"papersByCitationAuthors\": type=Edge, nodeType=\"papers\", Shape=(Int32)[128 8 8 8] (size=65536), SourceRule=\"citationsAuthors\", EdgeType=\"writes\"\n",
      "        > Rule \"papersByCitationAuthorsTopics\": type=Edge, nodeType=\"fields_of_study\", Shape=(Int32)[128 8 8 8 8] (size=524288), SourceRule=\"papersByCitationAuthors\", EdgeType=\"hasTopic\"\n",
      "      > Rule \"citationAuthorsInstitutions\": type=Edge, nodeType=\"institutions\", Shape=(Int32)[128 8 8 8] (size=65536), SourceRule=\"citationsAuthors\", EdgeType=\"affiliatedWith\"\n",
      "    > Rule \"citationsTopics\": type=Edge, nodeType=\"fields_of_study\", Shape=(Int32)[128 8 8] (size=8192), SourceRule=\"citations\", EdgeType=\"hasTopic\"\n",
      "  > Rule \"seedsBase\": type=Edge, nodeType=\"papers\", Shape=(Int32)[128 1] (size=128), SourceRule=\"seeds\", EdgeType=Identity\n",
      "    > Rule \"seedsAuthors\": type=Edge, nodeType=\"authors\", Shape=(Int32)[128 1 8] (size=1024), SourceRule=\"seedsBase\", EdgeType=\"writtenBy\"\n",
      "      > Rule \"papersByAuthors\": type=Edge, nodeType=\"papers\", Shape=(Int32)[128 1 8 8] (size=8192), SourceRule=\"seedsAuthors\", EdgeType=\"writes\"\n",
      "        > Rule \"papersByAuthorsTopics\": type=Edge, nodeType=\"fields_of_study\", Shape=(Int32)[128 1 8 8 8] (size=65536), SourceRule=\"papersByAuthors\", EdgeType=\"hasTopic\"\n",
      "      > Rule \"authorsInstitutions\": type=Edge, nodeType=\"institutions\", Shape=(Int32)[128 1 8 8] (size=8192), SourceRule=\"seedsAuthors\", EdgeType=\"affiliatedWith\"\n",
      "    > Rule \"seedsTopics\": type=Edge, nodeType=\"fields_of_study\", Shape=(Int32)[128 1 8] (size=1024), SourceRule=\"seedsBase\", EdgeType=\"hasTopic\"\n"
     ]
    }
   ],
   "source": [
    "%%\n",
    "must.M(mag.Download(*flagDataDir))\n",
    "magSampler := must.M1(mag.NewSampler(*flagDataDir))\n",
    "magStrategy := mag.NewSamplerStrategy(magSampler, mag.BatchSize, mag.TrainSplit)\n",
    "fmt.Printf(\"%s\\n\", magSampler)\n",
    "fmt.Printf(\"\\n%s\\n\", magStrategy)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e6e73ab-9617-4ec0-a398-627d6befddb1",
   "metadata": {},
   "source": [
    "### Training Model\n",
    "\n",
    "We use the vanilla GNN trainer and model defined in the `gnn` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "99743b12-d5b4-40c4-be48-25749b384c95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import (\n",
    "    \"github.com/gomlx/gomlx/examples/ogbnmag/gnn\"\n",
    ")\n",
    "\n",
    "func configGnn(baseDir string) *context.Context {\n",
    "    must.M(mag.Download(baseDir))\n",
    "    ctx := context.NewContext(manager)\n",
    "    ctx.RngStateReset()\n",
    "    \n",
    "    stepsPerEpoch := mag.TrainSplit.Shape().Size() / mag.BatchSize + 1\n",
    "    numEpochs := 10  // Taken from TF-GNN OGBN-MAG notebook.\n",
    "    numTrainSteps := numEpochs * stepsPerEpoch\n",
    "    \n",
    "    ctx.SetParams(map[string]any{\n",
    "        \"train_steps\": numTrainSteps,\n",
    "        \n",
    "        optimizers.ParamOptimizer: \"adam\", \n",
    "        optimizers.ParamLearningRate: 0.001,\n",
    "        optimizers.ParamCosineScheduleSteps:  numTrainSteps,\n",
    "        \n",
    "        layers.ParamL2Regularization: 1e-5,\n",
    "        layers.ParamDropoutRate: 0.2,\n",
    "\n",
    "        mag.ParamEmbedDropoutRate: 0.0,\n",
    "        \n",
    "        gnn.ParamEdgeDropoutRate: 0.0,\n",
    "        gnn.ParamNumGraphUpdates: 2,\n",
    "        gnn.ParamReadoutHiddenLayers: 2,\n",
    "        gnn.ParamPoolingType: \"mean|sum\",\n",
    "        gnn.ParamUsePathToRootStates: false,\n",
    "        \n",
    "        \"plots\": true,\n",
    "    })\n",
    "    return ctx\n",
    "}\n",
    "\n",
    "%%\n",
    "mag.BatchSize = 128  // Default is 128.\n",
    "ctx := configGnn(*flagDataDir)\n",
    "// ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"gnn-small_batch\"))\n",
    "// ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"gnn-use_path_to_root\"))\n",
    "ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"gnn-reg_2\"))\n",
    "ctx.SetParam(\"num_checkpoints\", 3)\n",
    "ctx.SetParams(map[string]any{\n",
    "    // \"train_steps\": 100,\n",
    "    // \"plots\": false,\n",
    "    // gnn.ParamEdgeDropoutRate: 0.1,\n",
    "    // mag.ParamEmbedDropoutRate: 0.1,\n",
    "    gnn.ParamNumGraphUpdates: 2,\n",
    "    // gnn.ParamUsePathToRootStates: true,\n",
    "    // gnn.ParamReadoutHiddenLayers: 2,\n",
    "    layers.ParamDropoutRate: 0.25,\n",
    "    layers.ParamL2Regularization: 3e-4,\n",
    "})\n",
    "\n",
    "// err := gnn.Train(ctx, *flagDataDir)\n",
    "// if  err != nil {\n",
    "//     fmt.Printf(\"%+v\\n\", err)\n",
    "// }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ee7239da-bb04-4cb6-b41c-662ed80436b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading: \"checkpoint-n0000171-20240229-012409-step-00049190\"\n",
      "Model in \"/home/janpf/work/ogbnmag/gnn-baseline_5\" trained for 49190 steps.\n",
      "Results on train:\n",
      "\tMean Loss+Regularization (#loss+): 1.542\n",
      "\tMean Loss (#loss): 1.457\n",
      "\tMean Accuracy (#acc): 57.49%\n",
      "\telapsed 1m37.457905276s (train)\n",
      "Results on valid:\n",
      "\tMean Loss+Regularization (#loss+): 1.927\n",
      "\tMean Loss (#loss): 1.841\n",
      "\tMean Accuracy (#acc): 48.60%\n",
      "\telapsed 11.140775268s (valid)\n",
      "Results on test:\n",
      "\tMean Loss+Regularization (#loss+): 1.922\n",
      "\tMean Loss (#loss): 1.837\n",
      "\tMean Accuracy (#acc): 48.10%\n",
      "\telapsed 7.647081029s (test)\n"
     ]
    }
   ],
   "source": [
    "%%\n",
    "ctx := configGnn(*flagDataDir)\n",
    "ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"gnn-baseline_5\"))\n",
    "\n",
    "_, trainEvalDS, validEvalDS, testEvalDS := must.M4(gnn.MakeDatasets(*flagDataDir))\n",
    "_, _, _ = trainEvalDS, validEvalDS, testEvalDS\n",
    "must.M(gnn.Eval(ctx, *flagDataDir, trainEvalDS, validEvalDS, testEvalDS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "abfe4a9b-0c47-4f01-a69b-88daf0029b3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading: \"checkpoint-n0000142-20240305-235708-step-00049190\"\n",
      "Model in \"/home/janpf/work/ogbnmag/gnn-baseline_17\" trained for 49190 steps.\n",
      "Results on valid:\n",
      "\tMean Loss+Regularization (#loss+): 1.920\n",
      "\tMean Loss (#loss): 1.834\n",
      "\tMean Accuracy (#acc): 48.64%\n",
      "\telapsed 9.335698266s (valid)\n",
      "Results on test:\n",
      "\tMean Loss+Regularization (#loss+): 1.917\n",
      "\tMean Loss (#loss): 1.830\n",
      "\tMean Accuracy (#acc): 48.38%\n",
      "\telapsed 5.863748141s (test)\n"
     ]
    }
   ],
   "source": [
    "%%\n",
    "ctx := configGnn(*flagDataDir)\n",
    "ctx.SetParam(\"checkpoint\", path.Join(*flagDataDir, \"gnn-baseline_17\"))\n",
    "\n",
    "_, trainEvalDS, validEvalDS, testEvalDS := must.M4(mag.MakeDatasets(*flagDataDir))\n",
    "_, _, _ = trainEvalDS, validEvalDS, testEvalDS\n",
    "must.M(mag.Eval(ctx, *flagDataDir, validEvalDS, testEvalDS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "22fcf866-2682-4296-bf5b-4cc896db3bc2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p><b>Metric: accuracy</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"a327da6a\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.29.1.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(81) train:#acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.14023835278308563,0.22449255127698067,0.2655363731811027,0.29292962985906273,0.3255216647526649,0.3434322737229002,0.36635582007430456,0.38735265760335214,0.4004361700268913,0.40937717906320337,0.42101367439097415,0.43278041714119614,0.4516249954333983,0.46329484680838223,0.4642049903823397,0.4772599119082677,0.49074528528156475,0.4952038769257161,0.5047961230742839,0.5100012548227285,0.5178065698705944,0.526537912324424,0.5282978409107154,0.5393688718190641,0.5402933108418272,0.54806526984248,0.5517534956343287,0.5534737146406045,0.5571190540860363,0.5570428116924064,0.5571937080964657,0.5571937080964657]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(81) valid:#acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.09358960526518596,0.2007737480540699,0.25573760384716165,0.2845604895266573,0.32975230814285056,0.33781346814839935,0.3614420690824458,0.36282926678894556,0.3823887544505926,0.393409269563341,0.4010234436412399,0.40557036945698915,0.42249418147628665,0.426655774595786,0.43024707532483547,0.4317884061098352,0.4480340325837328,0.4476024599639329,0.44342545353658347,0.4554324203517317,0.4547542348063318,0.46423341913408034,0.45555572681453166,0.4620755560350807,0.46489619137163024,0.4679171997102298,0.46782471986312985,0.46554355030133016,0.47013671604062945,0.4705991152761294,0.46965890349727957,0.46965890349727957]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(81) test:#acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.12308352607358305,0.2140012875843487,0.2641217005651065,0.2912563485061637,0.3395407615822981,0.34531104699682874,0.3660316173490069,0.36522091609241997,0.3814587853787644,0.39612294046114593,0.40172631679343807,0.4052552516750519,0.4204916664679654,0.42254226376403825,0.42599966618183555,0.43179379575097165,0.44943847015904054,0.4407115095734281,0.43770714609313527,0.4540165478432962,0.45005841817878345,0.46116979422494575,0.4514175349912969,0.4603590929683588,0.4607644435966523,0.4622427811821932,0.46550943036314646,0.4606929111328358,0.46734543026776987,0.46796537828751283,0.4674646510407974,0.4674646510407974]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(81) train:~acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.0923939049243927,0.1741478145122528,0.23376131057739258,0.269742488861084,0.2897125780582428,0.30884259939193726,0.3290688693523407,0.340726763010025,0.3556983768939972,0.3661079704761505,0.3787955641746521,0.38522276282310486,0.3956083655357361,0.4125228524208069,0.4080027937889099,0.4265553653240204,0.43158382177352905,0.43741121888160706,0.4420183300971985,0.4537088871002197,0.4612806737422943,0.4724980294704437,0.48075345158576965,0.4767076075077057,0.48481473326683044,0.48722338676452637,0.4992702007293701,0.4959866404533386,0.5002992749214172,0.5071849226951599,0.4996989667415619,0.4996989667415619]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(82) train:#acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.1822685606547951,0.25599654367815544,0.28944948226649575,0.31455229036915616,0.34870094079936975,0.36860655906958867,0.38850264704060383,0.40389725702105084,0.4332045154557627,0.4418818528807712,0.4471838760044538,0.4630184681314737,0.4741720949662548,0.4895190534506831,0.49146958802104923,0.5053965319241197,0.5141548768923601,0.5178065698705944,0.5278260911001301,0.5378837335264808,0.5438687614264317,0.5510418999604493,0.5577559957494865,0.5674641938716999,0.5712445458891848,0.5775678994108687,0.5818136477061364,0.5827380867288995,0.5877827917740811,0.587922569495736,0.588737410077656,0.588737410077656]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(82) valid:#acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.15276129410132708,0.2910649054393563,0.30011251714730497,0.32112085574685184,0.32036560366220196,0.3508222999737974,0.373449035897594,0.3840996316219424,0.41504955378473773,0.4092079101095886,0.41888746743938715,0.4293685167773856,0.44230028206353367,0.45436890211008185,0.440142418964534,0.4513478937714823,0.4572203640623314,0.4586383883845312,0.47286487153007906,0.47166263351777926,0.4672852540883799,0.4841474128762774,0.4720942061375792,0.4819433098537277,0.48120347107692785,0.4886943386920267,0.49260931888592613,0.49091385502242635,0.4939965165924259,0.49185406680127625,0.49354953066477597,0.49354953066477597]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(82) test:#acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.17172560146879992,0.3113569708385989,0.31352678890769925,0.3209184768354038,0.31989317818736734,0.35079520255609337,0.37428169484250934,0.3763084479839767,0.41140704356327046,0.4009394596914566,0.41212236820143544,0.4237583156489187,0.4326283411621641,0.4454564963399223,0.4349173800042919,0.4465056391425642,0.4541119244617182,0.45249052194854433,0.4693483392546317,0.4634111447578626,0.46429337847826607,0.48141348148501395,0.47099358592241114,0.4746894298862634,0.47461789742244687,0.4821526502777844,0.4871360785903336,0.48475166312978374,0.4869214811988841,0.48484703974820575,0.48592002670545315,0.48592002670545315]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(82) train:~acc\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[0.12776021659374237,0.21088115870952606,0.2548511326313019,0.28552547097206116,0.3151654005050659,0.334867000579834,0.34420934319496155,0.36771607398986816,0.3882007896900177,0.3943401277065277,0.4103657603263855,0.41614168882369995,0.4300278425216675,0.43826156854629517,0.4382134675979614,0.4523680806159973,0.4559232294559479,0.46090227365493774,0.4741823673248291,0.4776058495044708,0.4874690771102905,0.4930162727832794,0.5011693239212036,0.5104051828384399,0.5119450092315674,0.5162369608879089,0.530666172504425,0.5322320461273193,0.5366436839103699,0.5359376668930054,0.5436166524887085,0.5436166524887085]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"accuracy\"},\"xaxis\":{\"showgrid\":true,\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('a327da6a', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<p><b>Metric: loss</b></p>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div id=\"23dd27d6\"></div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<script charset=\"UTF-8\">\n",
       "(() => {\n",
       "\tconst src=\"https://cdn.plot.ly/plotly-2.29.1.min.js\";\n",
       "\tvar runJSFn = function(module) {\n",
       "\t\t\n",
       "\tif (!module) {\n",
       "\t\tmodule = window.Plotly;\n",
       "\t}\n",
       "\tlet data = JSON.parse('{\"data\":[{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(81) train:~loss\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[4.9117231369018555,3.9613046646118164,3.3920164108276367,3.084035873413086,2.880213975906372,2.7372663021087646,2.6221933364868164,2.541952133178711,2.4857726097106934,2.422252893447876,2.3696935176849365,2.3284120559692383,2.287005662918091,2.2090892791748047,2.222545862197876,2.1377522945404053,2.1233115196228027,2.0876996517181396,2.059360980987549,2.003631830215454,1.9763233661651611,1.9307252168655396,1.9058812856674194,1.9060614109039307,1.8773444890975952,1.8470205068588257,1.8261984586715698,1.8252078294754028,1.7880278825759888,1.8007527589797974,1.7778609991073608,1.7778609991073608]},{\"type\":\"scatter\",\"line\":{\"shape\":\"linear\"},\"mode\":\"lines+markers\",\"name\":\"(82) train:~loss\",\"x\":[200,440,728,1074,1489,1987,2585,3303,4165,4918,5199,6440,7929,9716,9837,11860,14433,14756,17521,19675,21227,24594,25674,29513,31010,34432,37413,39351,44270,45097,49189,49190],\"y\":[4.474365234375,3.5718984603881836,3.1486268043518066,2.8984782695770264,2.714951276779175,2.599757671356201,2.5068233013153076,2.420628309249878,2.3306710720062256,2.2848548889160156,2.220475196838379,2.1747941970825195,2.1210761070251465,2.087707757949829,2.079582452774048,2.015742063522339,1.980765700340271,1.959712266921997,1.915971279144287,1.8850421905517578,1.8437047004699707,1.814437985420227,1.788384199142456,1.741392970085144,1.7304532527923584,1.7099418640136719,1.6598256826400757,1.6511164903640747,1.6311211585998535,1.6248159408569336,1.5999643802642822,1.5999643802642822]}],\"layout\":{\"legend\":{},\"title\":{\"text\":\"loss\"},\"xaxis\":{\"showgrid\":true,\"type\":\"log\"},\"yaxis\":{\"showgrid\":true,\"type\":\"log\"}}}');\n",
       "\tmodule.newPlot('23dd27d6', data);\n",
       "\n",
       "\t}\n",
       "\t\n",
       "    if (typeof requirejs === \"function\") {\n",
       "        // Use RequireJS to load module.\n",
       "\t\tlet srcWithoutExtension = src.substring(0, src.lastIndexOf(\".js\"));\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': srcWithoutExtension\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(plotly) {\n",
       "            runJSFn(plotly)\n",
       "        });\n",
       "        return\n",
       "    }\n",
       "\n",
       "\tvar currentScripts = document.head.getElementsByTagName(\"script\");\n",
       "\tfor (const idx in currentScripts) {\n",
       "\t\tlet script = currentScripts[idx];\n",
       "\t\tif (script.src == src) {\n",
       "\t\t\trunJSFn(null);\n",
       "\t\t\treturn;\n",
       "\t\t}\n",
       "\t}\n",
       "\n",
       "\tvar script = document.createElement(\"script\");\n",
       "\n",
       "\tscript.charset = \"utf-8\";\n",
       "\t\n",
       "\tscript.src = src;\n",
       "\tscript.onload = script.onreadystatechange = function () { runJSFn(null); };\n",
       "\tdocument.head.appendChild(script);\t\n",
       "})();\n",
       "</script>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import (\n",
    "    stdplots \"github.com/gomlx/gomlx/examples/notebook/gonb/plots\"\n",
    "    \"github.com/gomlx/gomlx/examples/notebook/gonb/plotly\"\n",
    ")\n",
    "\n",
    "func filterPoints(pt *stdplots.Point) bool {\n",
    "    // Remove substrings\n",
    "    for _, s := range []string{\"Eval on \"} {\n",
    "        pt.MetricName = strings.Replace(pt.MetricName, s, \"\", -1)\n",
    "    }\n",
    "    // Replace substrings\n",
    "    for _, pair := range [][2]string{\n",
    "        {\"Accuracy\", \"acc\"}, {\"Loss+Regularization\", \"loss+\"}, {\"Loss\", \"loss\"}, {\" Mean \", \"#\"}, {\" Moving Average \", \"~\"}, \n",
    "        {\"Train\", \"train\"}, {\"Validation\", \"valid\"}, {\"Test\", \"test\"}, {\": layer-wise eval\", \":#acc\"}} {\n",
    "        pt.MetricName = strings.Replace(pt.MetricName, pair[0], pair[1], -1)\n",
    "    }\n",
    "    for _, exclude := range []string{ \"Batch\", \"Train:~loss\", \"Train:~acc\", \"loss+\" } {\n",
    "        if strings.Index(pt.MetricName, exclude) != -1 { return false }\n",
    "    }\n",
    "    return true\n",
    "}\n",
    "\n",
    "func plotVersions(versions ...int) {\n",
    "    plots := plotly.New()\n",
    "    for _, version := range versions {\n",
    "        // checkpoint := fmt.Sprintf(\"gnn-baseline_%d\", version)\n",
    "        checkpoint := fmt.Sprintf(\"small_%d\", version)\n",
    "        prefix := fmt.Sprintf(\"(%d) \", version)\n",
    "        must.M(plots.LoadCheckpointData(path.Join(*flagDataDir, checkpoint),\n",
    "            func (pt *stdplots.Point) bool {\n",
    "                if !filterPoints(pt) { return false }\n",
    "                pt.MetricName = prefix+pt.MetricName\n",
    "                return true\n",
    "            }))\n",
    "    }\n",
    "    plots.Plot()\n",
    "}\n",
    "\n",
    "%%\n",
    "plotVersions(81, 82)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b2fb28e0-2c85-4fcd-a40f-50173f76d43c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/tmp/gonb_b9950945\n"
     ]
    }
   ],
   "source": [
    "!echo $GONB_TMP_DIR"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Go (gonb)",
   "language": "go",
   "name": "gonb"
  },
  "language_info": {
   "codemirror_mode": "",
   "file_extension": ".go",
   "mimetype": "text/x-go",
   "name": "go",
   "nbconvert_exporter": "",
   "pygments_lexer": "",
   "version": "go1.24.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
